{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第一题：使用sklearn的GaussianNB、BernoulliNB、MultinomialNB完成spambase垃圾邮件分类任务\n",
    "\n",
    "实验内容：\n",
    "1. 使用GaussianNB、BernoulliNB、MultinomialNB完成spambase邮件分类\n",
    "2. 计算各自十折交叉验证的精度、查准率、查全率、F1值\n",
    "3. 根据精度、查准率、查全率、F1值的实际意义以及四个值的对比阐述三个算法在spambase中的表现对比"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 读取数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "spambase = np.loadtxt('data/spambase/spambase.data', delimiter = \",\")\n",
    "spamx = spambase[:, :57]\n",
    "spamy = spambase[:, 57]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. 导入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import GaussianNB\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn.naive_bayes import BernoulliNB\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import precision_score\n",
    "from sklearn.metrics import recall_score\n",
    "from sklearn.metrics import f1_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. 计算十折交叉验证下，GaussianNB、BernoulliNB、MultinomialNB的精度、查准率、查全率、F1值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GaussianNB在十折交叉验证下的四项指标\n",
      "精度: 0.82\n",
      "查准率: 0.7\n",
      "查全率: 0.96\n",
      "f1值: 0.81\n"
     ]
    }
   ],
   "source": [
    "# YOUR CODE HERE\n",
    "model = GaussianNB() # GaussianNB\n",
    "# YOUR CODE HERE\n",
    "\n",
    "prediction = cross_val_predict(model,spamx,spamy,cv = 10)\n",
    "acc1 = round(accuracy_score(spamy,prediction),2)\n",
    "precision1 = round(precision_score(spamy,prediction),2)\n",
    "recall1 = round(recall_score(spamy,prediction),2)\n",
    "f1 = round(f1_score(spamy,prediction),2)\n",
    "\n",
    "print(\"GaussianNB在十折交叉验证下的四项指标\")\n",
    "print(\"精度:\",acc1)\n",
    "print(\"查准率:\",precision1)\n",
    "print(\"查全率:\",recall1)\n",
    "print(\"f1值:\",f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BernoulliNB在十折交叉验证下的四项指标\n",
      "精度: 0.88\n",
      "查准率: 0.88\n",
      "查全率: 0.82\n",
      "f1值: 0.85\n"
     ]
    }
   ],
   "source": [
    "# YOUR CODE HERE\n",
    "model = BernoulliNB() # BernoulliNB\n",
    "# YOUR CODE HERE\n",
    "\n",
    "prediction = cross_val_predict(model,spamx,spamy,cv = 10)\n",
    "acc1 = round(accuracy_score(spamy,prediction),2)\n",
    "precision1 = round(precision_score(spamy,prediction),2)\n",
    "recall1 = round(recall_score(spamy,prediction),2)\n",
    "f1 = round(f1_score(spamy,prediction),2)\n",
    "\n",
    "print(\"BernoulliNB在十折交叉验证下的四项指标\")\n",
    "print(\"精度:\",acc1)\n",
    "print(\"查准率:\",precision1)\n",
    "print(\"查全率:\",recall1)\n",
    "print(\"f1值:\",f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MultinomialNB在十折交叉验证下的四项指标\n",
      "精度: 0.79\n",
      "查准率: 0.73\n",
      "查全率: 0.72\n",
      "f1值: 0.73\n"
     ]
    }
   ],
   "source": [
    "# YOUR CODE HERE\n",
    "model = MultinomialNB() # MultinomialNB\n",
    "# YOUR CODE HERE\n",
    "\n",
    "prediction = cross_val_predict(model,spamx,spamy,cv = 10)\n",
    "acc1 = round(accuracy_score(spamy,prediction),2)\n",
    "precision1 = round(precision_score(spamy,prediction),2)\n",
    "recall1 = round(recall_score(spamy,prediction),2)\n",
    "f1 = round(f1_score(spamy,prediction),2)\n",
    "\n",
    "print(\"MultinomialNB在十折交叉验证下的四项指标\")\n",
    "print(\"精度:\",acc1)\n",
    "print(\"查准率:\",precision1)\n",
    "print(\"查全率:\",recall1)\n",
    "print(\"f1值:\",f1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### 双击此处填写\n",
    "算法|精度|查准率|查全率|F1值\n",
    "-|-|-|-|-\n",
    "GaussianNB|0.82|0.7|0.96|0.81\n",
    "MultinomialNB|0.79|0.73|0.72|0.73\n",
    "BernoulliNB|0.88|0.88|0.82|0.85 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "分析如下：\n",
    "\n",
    "**GaussianNB算法在spambase中，查全率高，即其能对正例样本有很好的分类覆盖**\n",
    "\n",
    "**MultinomialNB算法在spambase中，不论是精度还是f1值均不高，表现不是很好**\n",
    "\n",
    "**BernoulliNB算法在spambase中，精度高，虽然查全率没有GaussianNB高，但是其查准率也高，f1值总体高**"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
